import argparse
import os
import json
from torchvision import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.nn.functional as F
import dataset_processing
import torch.optim as optim
from torch.autograd import Variable
import torch
import torchvision.models as models

import torchlib

from nets import resnext101, densenet161
from utils import calculate_metrics
import numpy as np

# --- Dataset layout ---------------------------------------------------------
DATA_PATH = 'data/AIDA25'  # Path to the dataset
TRAIN_DATA = 'train_img'
# NOTE(review): validation uses the same image directory as training;
# presumably the split is defined only by the *_img.txt lists below -- confirm.
VALIDATION_DATA = 'train_img'
TRAIN_IMG_FILE = 'train_img.txt'
TRAIN_LABEL_FILE = 'train_label.txt'
VALIDATION_IMG_FILE = 'validation_img.txt'
VALIDATION_LABEL_FILE = 'validation_label.txt'

### Choose the net for training
choose_net = 'densenet161'
# choose_net = 'resnext101'

# Checkpoints and logs live in per-network directories. `exist_ok=True`
# replaces the former isdir()-then-makedirs() pair, which could raise if the
# directory appeared between the check and the creation.
save_path = os.path.join('ckpt', f'{choose_net}_AIDA25')
log_path = os.path.join('log', f'{choose_net}_AIDA25')
os.makedirs(save_path, exist_ok=True)
os.makedirs(log_path, exist_ok=True)

# --- Hyperparameters --------------------------------------------------------
NLABELS = 27             # passed to the network constructor
batch_size = 16          # used by both the training and validation loaders
lr = 1e-3                # learning rate handed to the Adam optimizer
max_epoch_number = 100   # training stops once the epoch counter exceeds this
test_freq = 50           # run validation every `test_freq` training iterations
save_freq = 1            # write a checkpoint every `save_freq` epochs


### Dataloader
# Training pipeline: resize, then random augmentation, then tensor conversion.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.RandomRotation(degrees=15),
    # NOTE(review): ColorJitter() with default arguments appears to apply no
    # jitter at all -- confirm whether explicit strengths were intended.
    transforms.ColorJitter(),
    transforms.RandomHorizontalFlip(),
    # The image should already be 224x224 after RandomResizedCrop, so this
    # crop is kept only as a size guard.
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
test_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# Presumably the arguments are (dataset root, image sub-directory, image list
# file, label list file, transform) -- verify against DatasetProcessing.
dset_train = dataset_processing.DatasetProcessing(
    DATA_PATH, TRAIN_DATA, TRAIN_IMG_FILE, TRAIN_LABEL_FILE, train_transform)

dset_validation = dataset_processing.DatasetProcessing(
    DATA_PATH, VALIDATION_DATA, VALIDATION_IMG_FILE, VALIDATION_LABEL_FILE, test_transform)

# Training batches are shuffled; the last incomplete batch is dropped so every
# optimisation step sees exactly `batch_size` samples.
train_loader = DataLoader(
    dset_train,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)

# Validation must cover every sample in a fixed order. The previous
# shuffle=True / drop_last=True combination discarded the final partial batch,
# so each evaluation scored a different random subset of the validation set.
validation_loader = DataLoader(
    dset_validation,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    drop_last=False,
)
                         
### Initialize model
use_gpu = torch.cuda.is_available()

# Build the selected network. An unknown name now fails immediately with a
# clear message instead of leaving `model` undefined (NameError further down).
if choose_net == 'densenet161':
    model = densenet161(NLABELS)
elif choose_net == 'resnext101':
    model = resnext101(NLABELS)
else:
    raise ValueError(f"Unsupported choose_net: {choose_net!r}")

if use_gpu:
    model = model.cuda()

print(model)
optimizer = optim.Adam(model.parameters(), lr=lr)

# load checkpoint
# Best-effort resume: any failure falls back to training from epoch 0.
# `except Exception` (rather than a bare `except`) lets KeyboardInterrupt and
# SystemExit propagate, and the reason is printed so a corrupt or mismatched
# checkpoint is not mistaken for a missing one.
start_ep = 0
try:
    ckpt = torchlib.load_checkpoint(save_path)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optimizer'])
    # Only adopt the stored epoch once both state dicts loaded successfully.
    start_ep = ckpt['epoch']
except Exception as err:
    print(' [*] No checkpoint!')
    print(f' [*] Reason: {err!r}')

# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; presumably the
# networks end in a sigmoid -- confirm in `nets`.
criterion = nn.BCELoss()

# Device for every batch. The loop used to call `.cuda()` unconditionally and
# crashed on CPU-only machines even though `use_gpu` was computed.
device = torch.device('cuda' if use_gpu else 'cpu')

# Averaging modes reported for each metric, in the order they are logged.
_AVERAGES = ('micro', 'macro', 'samples', 'weighted')


def _evaluate(net, loader):
    """Run `net` over `loader` without gradients.

    Puts the network in eval mode for the pass and restores train mode
    afterwards. Returns a tuple ``(outputs, targets)`` of NumPy arrays built
    from the per-sample outputs and targets of every batch.
    """
    outputs, labels = [], []
    net.eval()
    try:
        with torch.no_grad():
            for val_imgs, val_targets in loader:
                batch_out = net(val_imgs.to(device))
                outputs.extend(batch_out.cpu().numpy())
                labels.extend(val_targets.cpu().numpy())
    finally:
        net.train()
    return np.array(outputs), np.array(labels)


def _format_metrics(result, kinds):
    """Format ``result['<average>/<kind>']`` entries as '<average> <kind>: x.xxx'.

    Entries are emitted kind by kind, each over all averaging modes, and
    joined by single spaces.
    """
    return " ".join(
        "{} {}: {:.3f}".format(avg, kind, result[f'{avg}/{kind}'])
        for kind in kinds
        for avg in _AVERAGES
    )


# `iteration` counts training steps across epochs; it restarts at 0 on resume.
iteration = 0
# The bound is inclusive, matching the previous exit test
# (`max_epoch_number < start_ep` -> break). Checking it up front avoids
# training an extra epoch when resuming a run that is already finished.
while start_ep <= max_epoch_number:
    batch_losses = []
    for imgs, targets in train_loader:
        imgs, targets = imgs.to(device), targets.to(device)

        optimizer.zero_grad()

        model_result = model(imgs)
        loss = criterion(model_result, targets.type(torch.float))

        loss.backward()
        optimizer.step()
        batch_losses.append(loss.item())

        # Periodic validation, including at iteration 0.
        if iteration % test_freq == 0:
            val_outputs, val_targets = _evaluate(model, validation_loader)
            result = calculate_metrics(val_outputs, val_targets)

            header = "epoch:{:2d} iter:{:3d} test: ".format(start_ep, iteration)
            # The console shows only F1; the log file gets every metric.
            print(header + _format_metrics(result, ('f1',)))
            metrics_file = os.path.join(log_path, f"{choose_net}_metrics.txt")
            with open(metrics_file, "a", encoding="utf-8") as log_file:
                log_file.write(
                    header
                    + _format_metrics(result, ('precision', 'recall', 'f1'))
                    + "\n"
                )

        iteration += 1

    # Mean training loss of the epoch; NaN if the loader yielded no batches
    # (avoids NumPy's empty-mean warning).
    loss_value = np.mean(batch_losses) if batch_losses else float('nan')
    print("epoch:{:2d} iter:{:3d} train: loss:{:.3f}".format(start_ep, iteration, loss_value))
    with open(os.path.join(log_path, f"{choose_net}_loss.txt"), "a", encoding="utf-8") as log_file:
        log_file.write("{:2d},{:3d},{:.3f}".format(start_ep, iteration, loss_value) + "\n")

    if start_ep % save_freq == 0:
        # The stored epoch is the next one to run, so a resume continues
        # after the epoch that just finished.
        torchlib.save_checkpoint(
            {'epoch': start_ep + 1,
             'model': model.state_dict(),
             'optimizer': optimizer.state_dict()},
            '%s/Epoch_(%d).ckpt' % (save_path, start_ep + 1),
            max_keep=10,
        )
    start_ep += 1
